package com.lauor.smpedr.core.executor;

import com.lauor.smpedr.Configuration;
import com.lauor.smpedr.MethodSignature;
import com.lauor.smpedr.core.handler.ResultSetHandler;
import com.lauor.smpedr.core.helper.EdrHelper;
import com.lauor.smpedr.core.helper.LogSqlHelper;
import com.lauor.smpedr.param.SqlArgMap;
import com.lauor.smpedr.transaction.Transaction;
import com.lauor.smpedr.utils.Str;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Default {@link Executor} that runs SQL through JDBC {@link PreparedStatement}s obtained from a
 * {@link Transaction}.
 *
 * <p>Each public operation first resolves argument values in the order their placeholders appear in
 * the SQL (via {@code EdrHelper}). It then rewrites the SQL with {@code EdrHelper.replaceSqlPlaceHolder}
 * (presumably into JDBC {@code ?} markers — confirm against EdrHelper) before preparing the statement.
 * When debug logging is enabled, every statement is logged with its bound values and the number of
 * rows returned/affected, even if execution failed.
 */
public class BaseExecutor implements Executor {
    private static final Logger LOG = LoggerFactory.getLogger(BaseExecutor.class);

    /** Supplies the JDBC connection and the optional per-statement timeout. */
    private final Transaction transaction;

    /** Maps result sets (and generated keys) onto Java objects. */
    private final ResultSetHandler resultSetHandler;

    private final Configuration configuration;

    public BaseExecutor(Transaction transaction, ResultSetHandler resultSetHandler, Configuration configuration) {
        this.transaction = transaction;
        this.resultSetHandler = resultSetHandler;
        this.configuration = configuration;
    }

    /**
     * Runs a query and maps its result set according to the calling method's signature.
     *
     * @param sql              SQL in the project's placeholder syntax
     * @param sqlArgs          argument values, picked in the order their placeholders appear in {@code sql}
     * @param clazz            element type to map rows onto
     * @param invokerSignature signature of the calling method; selects list / void / single-value handling
     * @return the mapped rows; empty for void methods, or when a single-value lookup maps to {@code null}
     * @throws Exception any JDBC or mapping failure
     */
    public <E> List<E> query(String sql, SqlArgMap sqlArgs, Class<E> clazz, MethodSignature invokerSignature) throws Exception {
        // Resolve argument values in the order their placeholders appear in the SQL.
        List paramList = EdrHelper.getParamsSortedBySql(sql, sqlArgs);
        // Rewrite the SQL into the form handed to JDBC.
        sql = EdrHelper.replaceSqlPlaceHolder(sql);

        int resultRows = 0;
        try (PreparedStatement preparedStatement = transaction.getConnection().prepareStatement(sql)) {
            this.setTimeOut(preparedStatement, transaction);
            this.bindParameters(preparedStatement, paramList);
            try (ResultSet resultSet = preparedStatement.executeQuery()) {
                List<E> rsList = this.handleResultSet(resultSet, clazz, invokerSignature);
                resultRows = rsList.size();
                return rsList;
            }
        } finally {
            // Logged in finally so that failed statements also appear in the debug output.
            if (LOG.isDebugEnabled()) {
                LOG.debug(LogSqlHelper.buildSqlLogStr(sql, paramList, resultRows));
            }
        }
    }

    /**
     * Maps {@code resultSet} into a list whose shape depends on the calling method's return type.
     * A single-value method gets a list of at most one element.
     */
    private <E> List<E> handleResultSet(ResultSet resultSet, Class<E> clazz, MethodSignature invokerSignature) throws Exception {
        if (invokerSignature.isReturnsList()) {
            return resultSetHandler.<E>handleResultSetsForMany(resultSet, clazz);
        }
        if (invokerSignature.isReturnsVoid()) {
            return Collections.<E>emptyList();
        }
        // Single-value method: wrap the mapped value, if any, in a one-element list.
        E result = resultSetHandler.handleResultSetsForSingle(resultSet, clazz);
        if (result == null) {
            return Collections.<E>emptyList();
        }
        return Collections.singletonList(result);
    }

    /**
     * Executes an UPDATE-style statement.
     *
     * <p>Values for the part before {@code " where "} are taken from the fields of {@code obj}.
     * Values for the WHERE part are taken from {@code sqlArgs}. They are bound in that order.
     *
     * @return the affected row count reported by the driver
     * @throws Exception any JDBC failure
     */
    public int update(String sql, Object obj, SqlArgMap sqlArgs) throws Exception {
        // NOTE(review): this split is case-sensitive, so an upper-case " WHERE " leaves the whole
        // statement in the SET part. Confirm this agrees with EdrHelper.getSqlWherePart.
        String setPart = sql.split(" where ")[0];

        // Collect into a fresh mutable list. Appending to Collections.EMPTY_LIST, or to the list
        // returned by EdrHelper, would either throw UnsupportedOperationException or mutate the
        // helper's list in place.
        List<Object> valList = new ArrayList<>();
        List<List> setValues = EdrHelper.getInsertParamsSortedBySql(setPart, Arrays.asList(obj));
        if (!setValues.isEmpty()) {
            valList.addAll(setValues.get(0));
        }
        String sqlWherePart = EdrHelper.getSqlWherePart(sql);
        if (!Str.isNull(sqlWherePart)) {
            valList.addAll(EdrHelper.getParamsSortedBySql(sqlWherePart, sqlArgs));
        }

        // Rewrite the SQL into the form handed to JDBC.
        sql = EdrHelper.replaceSqlPlaceHolder(sql);
        int affectRows = 0;
        try (PreparedStatement preparedStatement = transaction.getConnection().prepareStatement(sql)) {
            this.setTimeOut(preparedStatement, transaction);
            this.bindParameters(preparedStatement, valList);
            affectRows = preparedStatement.executeUpdate();
            return affectRows;
        } finally {
            if (LOG.isDebugEnabled()) {
                LOG.debug(LogSqlHelper.buildSqlLogStr(sql, valList, affectRows));
            }
        }
    }

    /**
     * Executes {@code sql} once per element of {@code dataList} as a single JDBC batch.
     * Generated keys are handed to the {@link ResultSetHandler} for write-back into {@code dataList}.
     *
     * @return the total affected row count across the batch
     * @throws Exception any JDBC or key-mapping failure
     */
    public <E> int executeBatch(String sql, List<E> dataList, Class cls) throws Exception {
        List<List> paramList = EdrHelper.getInsertParamsSortedBySql(sql, dataList);
        sql = EdrHelper.replaceSqlPlaceHolder(sql);

        int affectRows = 0;
        try (PreparedStatement preparedStatement = transaction.getConnection().prepareStatement(sql, Statement.RETURN_GENERATED_KEYS)) {
            this.setTimeOut(preparedStatement, transaction);
            for (List rowParams : paramList) {
                this.bindParameters(preparedStatement, rowParams);
                preparedStatement.addBatch();
            }
            for (int updateCount : preparedStatement.executeBatch()) {
                if (updateCount >= 0) {
                    affectRows += updateCount;
                } else if (updateCount == Statement.SUCCESS_NO_INFO) {
                    // The driver reports success without a count. Count it as one row, assuming
                    // each batch entry inserts a single data object (verify for non-INSERT SQL).
                    affectRows += 1;
                }
                // Statement.EXECUTE_FAILED adds nothing instead of subtracting from the total.
            }
            try (ResultSet resultSet = preparedStatement.getGeneratedKeys()) {
                resultSetHandler.handleGeneratedKeys(resultSet, dataList, cls);
            }
        } finally {
            if (LOG.isDebugEnabled()) {
                LOG.debug(LogSqlHelper.buildSqlInsertLogStr(sql, paramList, affectRows));
            }
        }
        return affectRows;
    }

    /**
     * Executes a DELETE-style statement.
     *
     * @return the affected row count reported by the driver
     * @throws Exception any JDBC failure
     */
    public int delete(String sql, SqlArgMap sqlArgs) throws Exception {
        // Resolve argument values in the order their placeholders appear in the SQL.
        List paramList = EdrHelper.getParamsSortedBySql(sql, sqlArgs);
        // Rewrite the SQL into the form handed to JDBC.
        sql = EdrHelper.replaceSqlPlaceHolder(sql);
        int affectRows = 0;
        try (PreparedStatement preparedStatement = transaction.getConnection().prepareStatement(sql)) {
            this.setTimeOut(preparedStatement, transaction);
            this.bindParameters(preparedStatement, paramList);
            affectRows = preparedStatement.executeUpdate();
            return affectRows;
        } finally {
            if (LOG.isDebugEnabled()) {
                LOG.debug(LogSqlHelper.buildSqlLogStr(sql, paramList, affectRows));
            }
        }
    }

    /** Commits the transaction only when {@code force} is true; otherwise does nothing. */
    @Override
    public void commit(boolean force) throws Exception {
        if (force) {
            this.transaction.commit();
        }
    }

    /** Equivalent to {@code commit(false)}, which is a no-op in this executor. */
    public void commit() throws Exception {
        this.commit(false);
    }

    /** Rolls back the transaction only when {@code force} is true; otherwise does nothing. */
    @Override
    public void rollback(boolean force) throws Exception {
        if (force) {
            this.transaction.rollback();
        }
    }

    /** Equivalent to {@code rollback(false)}, which is a no-op in this executor. */
    public void rollback() throws Exception {
        this.rollback(false);
    }

    /**
     * Rolls back when {@code forceRollback} is true, then always closes the transaction's connection.
     *
     * <p>If the rollback fails, that failure is rethrown. A failure to close is then attached to it
     * as a suppressed exception rather than replacing it.
     */
    public void close(boolean forceRollback) throws Exception {
        Throwable rollbackFailure = null;
        try {
            this.rollback(forceRollback);
        } catch (Throwable t) {
            rollbackFailure = t;
            throw t;
        } finally {
            try {
                transaction.getConnection().close();
            } catch (Exception closeFailure) {
                if (rollbackFailure == null) {
                    throw closeFailure;
                }
                rollbackFailure.addSuppressed(closeFailure);
            }
        }
    }

    @Override
    public boolean isClosed() {
        return this.transaction.isClosed();
    }

    @Override
    public Configuration getConfiguration() {
        return this.configuration;
    }

    /** Binds {@code values} to the statement's positional parameters, starting at index 1. */
    private void bindParameters(PreparedStatement statement, List<?> values) throws Exception {
        int index = 0;
        for (Object value : values) {
            statement.setObject(++index, value);
        }
    }

    /** Applies the transaction's query timeout (if one is configured) to {@code statement}. */
    private void setTimeOut(Statement statement, Transaction transaction) throws Exception {
        Integer timeOut = transaction.getTimeout();
        if (timeOut != null) {
            statement.setQueryTimeout(timeOut);
        }
    }
}